<!doctype html>
<!-- The doctype keeps browsers out of quirks mode; lang and charset make the
     document language and encoding explicit. -->
<html lang="en">

<head>
  <meta charset="utf-8">
  <title>TensorFlow.js Model Benchmark</title>
  <link href="https://fonts.googleapis.com/css?family=Roboto" rel="stylesheet">
  <style>
    /* Page shell: the root element and the body are both column flex containers. */
    html,
    body {
      font-family: 'Roboto', sans-serif;
      font-size: 13px;
      display: flex;
      flex-direction: column;
      box-sizing: border-box;
      position: relative;
    }

    body {
      margin: 20px 100px;
    }

    h2 {
      margin-bottom: 30px;
    }

    /* Caps the width of the per-kernel table. */
    #kernels {
      max-width: 750px;
    }

    /* Lays the stats column and the kernel table side by side, wrapping when needed. */
    #container {
      display: flex;
      flex-direction: row;
      flex-wrap: wrap;
    }

    /* Generic spaced panel (environment dump, trendline). */
    .box {
      margin-right: 30px;
      margin-bottom: 30px;
    }

    .box pre {
      margin: 0;
      border: 1px solid #ccc;
      padding: 8px;
      font-size: 10px;
    }

    /* Trendline chart: only the bottom and left borders are drawn on the SVG. */
    #trendline-container svg {
      overflow: visible;
      border-bottom: 1px solid #ccc;
      border-left: 1px solid #ccc;
    }

    #trendline-container .label {
      font-size: 14px;
      font-weight: bold;
    }

    /* The chart line is stroked only, never filled. */
    #trendline-container path {
      fill: none;
      stroke: #222;
    }

    /* Positioning context for the absolutely placed y-axis labels below. */
    #trendline {
      position: relative;
      margin-top: 20px;
    }

    /* Y-axis labels sit 6px outside the left edge of the chart. */
    #trendline #yMax, #trendline #yMin {
      position: absolute;
      right: calc(100% + 6px);
      font-size: 11px;
      white-space: nowrap;
    }

    #trendline #yMin {
      bottom: 0;
    }

    #trendline #yMax {
      top: 0;
    }

    /* Status banner; hidden by default, its display is toggled from script. */
    #modal-msg {
      border-radius: 5px;
      background-color: black;
      color: white;
      padding: 7px;
      top: 15px;
      left: 45%;
      display: none;
      position: absolute;
    }

    /* Shared look for the tables carrying the "table" class (#timings, #kernels). */
    .table {
      margin-right: 30px;
      margin-bottom: 30px;
      border: 1px solid #ccc;
      border-collapse: collapse;
      border-spacing: 0;
    }

    .table tr {
      border-bottom: 1px solid #ddd;
    }

    /* Zebra striping. */
    .table tr:nth-child(even) {
      background-color: #f1f1f1;
    }

    .table th {
      font-weight: bold;
    }

    /* Both selectors are scoped to .table so that header cells of any other
       table on the page are left untouched. */
    .table td,
    .table th {
      padding: 8px 8px;
      font-size: 13px;
      text-align: left;
      vertical-align: top;
    }

    /* Extra left padding on the first column, again limited to .table. */
    .table td:first-child,
    .table th:first-child {
      padding-left: 16px;
    }
  </style>
  <script src="https://cdnjs.cloudflare.com/ajax/libs/dat-gui/0.7.2/dat.gui.min.js"></script>
</head>

<body>
  <h2>TensorFlow.js Model Benchmark</h2>
  <!-- Progress banner filled in by script; role="status" lets assistive
       technology announce the updates. -->
  <div id="modal-msg" role="status"></div>
  <div id="container">
    <div id="stats">
      <!-- Library versions and environment features are written here as JSON. -->
      <div class="box">
        <pre id="env"></pre>
      </div>
      <!-- One row per recorded measurement; rows are appended by script. -->
      <table class="table" id="timings">
        <thead>
          <tr>
            <th scope="col">Type</th>
            <th scope="col">Value</th>
          </tr>
        </thead>
        <tbody>
        </tbody>
      </table>
      <!-- Line chart of per-run inference times; size and path are set by script. -->
      <div class="box" id="trendline-container">
        <div class="label"></div>
        <div id="trendline">
          <div id="yMax"></div>
          <div id="yMin">0 ms</div>
          <svg><path></path></svg>
        </div>
      </div>
    </div>
    <!-- One row per profiled kernel; rows are appended by script. -->
    <table class="table" id="kernels">
      <thead>
        <tr>
          <th scope="col">Kernel</th>
          <th scope="col">Time(ms)</th>
          <th scope="col">Output</th>
          <th scope="col">GPUPrograms</th>
        </tr>
      </thead>
      <tbody></tbody>
    </table>
  </div>
  <script src="https://unpkg.com/@tensorflow/tfjs-core/dist/tf-core.js"></script>
  <script src="https://unpkg.com/@tensorflow/tfjs-layers/dist/tf-layers.js"></script>
  <script src="https://unpkg.com/@tensorflow/tfjs-converter/dist/tf-converter.js"></script>
  <script>
    'use strict';
    /**
     * Loads the model to benchmark. Replace the body to benchmark another model.
     * @returns {Promise<Object>} Resolves with whatever tf.loadFrozenModel yields.
     */
    async function load() {
      //////////////////////////////////
      // Place model loading code here.
      //////////////////////////////////
      const modelUrl =
        'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/tensorflowjs_model.pb';
      const weightsUrl =
        'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/weights_manifest.json';
      const loaded = await tf.loadFrozenModel(modelUrl, weightsUrl);
      return loaded;
    }

    // All-zero input tensor, created once and reused for every inference.
    const zeros = tf.zeros([1, 224, 224, 3]);

    /**
     * Runs one inference on the shared zero input.
     * @param {Object} model The loaded model.
     * @returns The result of model.executeAsync(zeros) when the page-level
     *     isAsync flag is truthy, otherwise the result of model.predict(zeros).
     */
    function predict(model) {
      //////////////////////////////////
      // Place model prediction code here.
      //////////////////////////////////
      return isAsync ? model.executeAsync(zeros) : model.predict(zeros);
    }
  </script>
  <script>
    'use strict';
    // Settings object edited through the dat.GUI panel.
    const state = {numRuns: 50};

    // DOM handles reused by the reporting helpers.
    const modalDiv = document.querySelector('#modal-msg');
    const timeTable = document.querySelector('#timings tbody');
    const envDiv = document.querySelector('#env');

    // Assigned once the model has been loaded.
    let model;
    let isAsync;

    /**
     * Shows `message` (followed by '...') in the status banner, or hides the
     * banner when `message` is null/undefined, then waits two animation frames.
     * @param {?string} message Banner markup, or null to hide the banner.
     */
    async function showMsg(message) {
      const hasMessage = message != null;
      if (hasMessage) {
        modalDiv.innerHTML = message + '...';
      }
      modalDiv.style.display = hasMessage ? 'block' : 'none';

      // Two frames, presumably so the banner change is painted before the
      // caller continues with blocking work.
      for (let frame = 0; frame < 2; frame++) {
        await tf.nextFrame();
      }
    }

    /** Writes the core/layers/converter version strings into the env panel as JSON. */
    function showVersions() {
      const versions = {
        core: tf.version_core,
        layers: tf.version_layers,
        converter: tf.version_converter
      };
      const pretty = JSON.stringify(versions, null, 2);
      envDiv.innerHTML = pretty;
    }

    /**
     * Appends the tf.ENV feature flags to the env panel as JSON.
     * A tiny timed op runs first; presumably this forces the environment
     * features to be evaluated before they are printed (TODO confirm).
     */
    async function showEnvironment() {
      const tinyOp = () => tf.square(3).data();
      await tf.time(tinyOp);
      const features = JSON.stringify(tf.ENV.features, null, 2);
      envDiv.innerHTML += `<br/>${features}`;
    }

    /**
     * Formats a duration with one decimal place and an ' ms' suffix.
     * @param {number} elapsed Duration in milliseconds.
     * @returns {string} e.g. '12.3 ms'.
     */
    function printTime(elapsed) {
      const rounded = elapsed.toFixed(1);
      return `${rounded} ms`;
    }

    /**
     * Formats a byte count as B, KB or MB (1024-based); KB and MB use two decimals.
     * @param {number} bytes Size in bytes.
     * @returns {string} e.g. '512 B', '1.50 KB', '5.00 MB'.
     */
    function printMemory(bytes) {
      const KB = 1024;
      const MB = KB * KB;
      if (bytes < KB) {
        return bytes + ' B';
      }
      if (bytes < MB) {
        return (bytes / KB).toFixed(2) + ' KB';
      }
      return (bytes / MB).toFixed(2) + ' MB';
    }

    /**
     * Appends one table row to `tbody`, with one cell per remaining argument.
     * HTMLElement arguments are inserted as child nodes; anything else is
     * assigned to the cell's textContent.
     * @param {HTMLElement} tbody Table body that receives the row.
     * @param {...*} cells Cell contents, in column order.
     */
    function appendRow(tbody, ...cells) {
      const row = document.createElement('tr');
      for (const content of cells) {
        const cell = document.createElement('td');
        if (!(content instanceof HTMLElement)) {
          cell.textContent = content;
        } else {
          cell.appendChild(content);
        }
        row.appendChild(cell);
      }
      tbody.appendChild(row);
    }

    /**
     * Runs the first inference on the loaded model and records its wall-clock
     * time as the '1st inference' row of the timings table.
     */
    async function warmUpAndRecordTime() {
      await showMsg('Warming up');
      const start = performance.now();
      let res = predict(model);
      if (res instanceof Promise) {
        res = await res;
      }

      if (res instanceof tf.Tensor) {
        // Block until the values are available so the timing covers the whole run.
        res.dataSync();
      }

      const elapsed = performance.now() - start;
      // The output is only needed for timing; release it (outside the timed
      // region) so it does not stay allocated for the rest of the session.
      tf.dispose(res);
      await showMsg(null);
      appendRow(timeTable, '1st inference', printTime(elapsed));
    }

    /**
     * Returns a promise that resolves once a timer of `timeMs` milliseconds fires.
     * @param {number} timeMs Delay handed to setTimeout.
     * @returns {Promise<void>}
     */
    function sleep(timeMs) {
      return new Promise(function (wake) {
        setTimeout(wake, timeMs);
      });
    }

    /**
     * Loads the model via load(), stores it in the page-level `model` variable
     * and records the elapsed time as the 'Model load' row.
     * Also sets `isAsync`, which predict() consults to choose executeAsync.
     */
    async function loadAndRecordTime() {
      await showMsg('Loading the model');
      const begin = performance.now();
      model = await load();
      const executor = model.executor;
      isAsync = executor != null && executor.isControlFlowModel;

      const loadMs = performance.now() - begin;
      await showMsg(null);
      appendRow(timeTable, 'Model load', printTime(loadMs));
    }

    /**
     * Runs predict() repeatedly, draws the per-run times as a trendline and
     * appends the average and best time to the timings table.
     *
     * The run count is read once from state.numRuns and coerced to an integer
     * of at least 1, so edits made in the GUI while the loop is running, or
     * non-positive / non-numeric values, cannot produce an empty sample set
     * or mismatched labels.
     */
    async function measureAveragePredictTime() {
      const numRuns = Math.max(1, Math.floor(Number(state.numRuns)) || 1);
      const container = document.querySelector('#trendline-container');
      const svg = container.querySelector('svg');

      container.querySelector('.label').textContent = `Inference times over ${numRuns} runs`;
      await showMsg(`Running predict ${numRuns} times`);
      const chartHeight = 150;
      const chartWidth = container.getBoundingClientRect().width;
      svg.setAttribute('width', chartWidth);
      svg.setAttribute('height', chartHeight);

      const times = [];
      for (let i = 0; i < numRuns; i++) {
        const start = performance.now();
        let res = predict(model);
        if (res instanceof Promise) {
          res = await res;
        }

        if (res instanceof tf.Tensor) {
          // Block until the values are available so the run is fully timed.
          res.dataSync();
        }

        times.push(performance.now() - start);
        // Release the output outside the timed region; otherwise every run
        // would leave one more result allocated.
        tf.dispose(res);
      }

      const average = times.reduce((acc, curr) => acc + curr, 0) / times.length;
      const max = Math.max(...times);
      const min = Math.min(...times);
      const xIncrement = chartWidth / times.length;

      // Scale each sample against the slowest run; a zero maximum would
      // otherwise turn every y coordinate into NaN.
      const points = times.map((d, i) => {
        const ratio = max > 0 ? d / max : 0;
        return `${i * xIncrement},${chartHeight - ratio * chartHeight}`;
      });

      container.querySelector('#yMax').textContent = printTime(max);
      svg.querySelector('path').setAttribute('d', `M${points.join('L')}`);

      await showMsg(null);
      appendRow(timeTable, `Subsequent average (${numRuns} runs)`, printTime(average));
      appendRow(timeTable, 'Best time', printTime(min));
    }

    /**
     * Runs one inference inside tf.profile and records the reported peak
     * memory plus the elapsed time ('2nd inference'; the measured span
     * includes the tf.profile call itself).
     */
    async function profileMemory() {
      await showMsg('Profile memory');
      const start = performance.now();
      let res;
      const data = await tf.profile(() => res = predict(model));
      if (res instanceof Promise) {
        res = await res;
      }

      if (res instanceof tf.Tensor) {
        // Block until the values are available so the timing covers the whole run.
        res.dataSync();
      }
      const elapsed = performance.now() - start;
      // The output is only needed for timing; release it so it does not stay
      // allocated during the later benchmark stages.
      tf.dispose(res);
      await showMsg(null);
      appendRow(timeTable, 'Peak memory', printMemory(data.peakBytes));
      appendRow(timeTable, '2nd inference', printTime(elapsed));
    }

    /**
     * Adds one row per kernel record to the #kernels table. The innermost scope
     * is shown as the name; the enclosing scopes go into the tooltip.
     * @param {Array<Object>} kernels Records with scopes, time, output and
     *     gpuProgramsInfo fields.
     */
    function showKernelTime(kernels) {
      const kernelBody = document.querySelector('#kernels tbody');
      for (const kernel of kernels) {
        const scopes = kernel.scopes;
        const label = document.createElement('span');
        label.setAttribute('title', scopes.slice(0, -1).join(' --> '));
        label.textContent = scopes[scopes.length - 1];
        appendRow(
          kernelBody, label, kernel.time.toFixed(2), kernel.output,
          kernel.gpuProgramsInfo);
      }
    }

    /**
     * Collects per-kernel timings by switching the DEBUG flag on and capturing
     * the tab-separated lines that are then written to console.log during one
     * inference. The captured kernels are sorted by time (slowest first) and
     * rendered into the kernels table.
     *
     * console.log and the DEBUG flag are restored in a finally block, so a
     * failing inference cannot leave the page with a hijacked console.
     */
    async function profileKernelTime() {
      await showMsg('Profiling kernels');
      const oldLog = console.log;
      const kernels = [];
      tf.ENV.set('DEBUG', true);
      console.log = (...args) => {
        const msg = args[0];
        // Each tab-separated field loses its first two characters; presumably
        // these are console formatting prefixes (TODO confirm).
        const parts = typeof msg === 'string' ?
          msg.split('\t').map(x => x.slice(2)) :
          [];

        if (parts.length > 2) {
          // heuristic for determining whether we've caught a profiler
          // log statement as opposed to a regular console.log
          // TODO(https://github.com/tensorflow/tfjs/issues/563): return timing information as part of tf.profile
          const scopes = parts[0].trim()
            .split('||')
            .filter(s => s !== 'unnamed scope');
          kernels.push({
            scopes: scopes,
            time: Number.parseFloat(parts[1]),
            output: parts[2].trim(),
            gpuProgramsInfo: parts[4]
          });
        } else {
          // Not a profiler line: pass every argument through untouched.
          oldLog.apply(console, args);
        }
      };

      try {
        let res = predict(model);
        if (res instanceof Promise) {
          res = await res;
        }

        if (res instanceof tf.Tensor) {
          res.dataSync();
        }
        // The output itself is not needed; release it.
        tf.dispose(res);

        await showMsg(null);
        // Short pause before the logger is restored, presumably to let
        // late profiler lines arrive (TODO confirm).
        await sleep(10);
      } finally {
        tf.ENV.set('DEBUG', false);
        // Switch back to the old log.
        console.log = oldLog;
      }

      kernels.sort((a, b) => b.time - a.time);
      appendRow(timeTable, 'Number of kernels', kernels.length);
      showKernelTime(kernels);
    }

    /**
     * Reports whether the WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION feature
     * is greater than zero. Reads the environment through tf.ENV, like the
     * rest of the page, instead of relying on a bare global named ENV.
     * @returns {boolean}
     */
    function queryTimerIsEnabled() {
      return tf.ENV.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0;
    }

    /**
     * Entry point: builds the settings panel, then runs the benchmark stages
     * in order (versions, environment, load, warm-up, memory profile, repeated
     * inference, kernel profile when the timer extension is available).
     */
    async function onPageLoad() {
      const gui = new dat.gui.GUI();
      gui.add(state, 'numRuns').onChange(() => {
        // The control is usable before loading finishes; without a model
        // there is nothing to measure yet, so ignore the edit.
        if (model == null) {
          return;
        }
        measureAveragePredictTime();
      });
      showVersions();
      await showEnvironment();
      await loadAndRecordTime();
      await warmUpAndRecordTime();
      await showMsg('Waiting for GC');
      await sleep(1000);
      await profileMemory();
      await sleep(200);
      await measureAveragePredictTime();
      await sleep(200);
      if (queryTimerIsEnabled()) {
        await profileKernelTime();
      } else {
        await showMsg('Skipping kernel times since query timer extension is not ' +
                      'available. <br/> Use Chrome 70+.');
      }
    }
    onPageLoad();
  </script>
</body>
</html>
